from argparse import ArgumentParser, Namespace
import pickle

import jax
from jax import jit
import jax.numpy as jnp
import numpy as np
from PIL import Image

from .model import build_thera
from .utils import make_grid, interpolate_grid

MEAN = jnp.array([.4488, .4371, .4040])
VAR = jnp.array([.25, .25, .25])
# PATCH_SIZE = 256


def process_single(source, apply_encoder, apply_decoder, params, target_shape, PATCH_SIZE):
    t = jnp.float32((target_shape[0] / source.shape[1])**-2)[None]
    coords_nearest = jnp.asarray(make_grid(target_shape)[None])
    source_up = interpolate_grid(coords_nearest, source[None])
    source = jax.nn.standardize(source, mean=MEAN, variance=VAR)[None]

    encoding = apply_encoder(params, source)
    coords = jnp.asarray(make_grid(source_up.shape[1:3])[None])  # global sampling coords
    out = jnp.full_like(source_up, jnp.nan, dtype=jnp.float32)

    for h_min in range(0, coords.shape[1], PATCH_SIZE):
        h_max = min(h_min + PATCH_SIZE, coords.shape[1])
        for w_min in range(0, coords.shape[2], PATCH_SIZE):
            # apply decoder with one patch of coordinates
            w_max = min(w_min + PATCH_SIZE, coords.shape[2])
            coords_patch = coords[:, h_min:h_max, w_min:w_max]
            out_patch = apply_decoder(params, encoding, coords_patch, t)
            out = out.at[:, h_min:h_max, w_min:w_max].set(out_patch)

    out = out * jnp.sqrt(VAR)[None, None, None] + MEAN[None, None, None]
    out += source_up
    return out


def process(source, model, params, target_shape, PATCH_SIZE, do_ensemble=True):
    apply_encoder = jit(model.apply_encoder)
    apply_decoder = jit(model.apply_decoder)

    outs = []
    for i_rot in range(4 if do_ensemble else 1):
        source_ = jnp.rot90(source, k=i_rot, axes=(-3, -2))
        target_shape_ = tuple(reversed(target_shape)) if i_rot % 2 else target_shape
        out = process_single(source_, apply_encoder, apply_decoder, params, target_shape_, PATCH_SIZE)
        outs.append(jnp.rot90(out, k=i_rot, axes=(-2, -3)))

    out = jnp.stack(outs).mean(0).clip(0., 1.)
    return jnp.rint(out[0] * 255).astype(jnp.uint8)


# def main(args: Namespace):
#     source = np.asarray(Image.open(args.in_file)) / 255.

#     if args.scale is not None:
#         if args.size is not None:
#             raise ValueError('Cannot specify both size and scale')
#         target_shape = (
#             round(source.shape[0] * args.scale),
#             round(source.shape[1] * args.scale),
#         )
#     elif args.size is not None:
#         target_shape = args.size
#     else:
#         raise ValueError('Must specify either size or scale')

#     with open(args.checkpoint, 'rb') as fh:
#         check = pickle.load(fh)
#         params, backbone, size = check['model'], check['backbone'], check['size']

#     model = build_thera(3, backbone, size)

#     out = process(source, model, params, target_shape, not args.no_ensemble)

#     Image.fromarray(np.asarray(out)).save(args.out_file)


# def parse_args() -> Namespace:
#     parser = ArgumentParser()
#     parser.add_argument('in_file')
#     parser.add_argument('out_file')
#     parser.add_argument('--scale', type=float, help='Scale factor for super-resolution')
#     parser.add_argument('--size', type=int, nargs=2,
#                         help='Target size (h, w), mutually exclusive with --scale')
#     parser.add_argument('--checkpoint', help='Path to checkpoint file')
#     parser.add_argument('--no-ensemble', action='store_true', help='Disable geo-ensemble')
#     return parser.parse_args()


# if __name__ == '__main__':
#     args = parse_args()
#     main(args)
